{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow IO Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qFdPvlXBOdUN"
      },
      "source": [
        "# End-to-end example for BigQuery TensorFlow reader"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/io/tutorials/bigquery\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/io/blob/master/docs/tutorials/bigquery.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/io/blob/master/docs/tutorials/bigquery.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "      <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/io/docs/tutorials/bigquery.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xHxb-dlhMIzW"
      },
      "source": [
        "## Overview\n",
        "\n",
        "This tutorial shows how to use the [BigQuery TensorFlow reader](https://github.com/tensorflow/io/tree/master/tensorflow_io/bigquery) to train a neural network with the Keras Sequential API."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WodUv8O1VKmr"
      },
      "source": [
        "### Dataset\n",
        "\n",
        "This tutorial uses the [United States Census Income\n",
        "Dataset](https://archive.ics.uci.edu/ml/datasets/census+income) provided by the\n",
        "[UC Irvine Machine Learning\n",
        "Repository](https://archive.ics.uci.edu/ml/index.php). This dataset contains\n",
        "information about people from a 1994 Census database, including age, education,\n",
        "marital status, occupation, and whether they make more than $50,000 a year."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MUXex9ctTuDB"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4YsfgDMZW5g6"
      },
      "source": [
        "Set up your GCP project\n",
        "\n",
        "**The following steps are required, regardless of your notebook environment.**\n",
        "\n",
        "1. [Select or create a GCP project.](https://console.cloud.google.com/cloud-resource-manager)\n",
        "2. [Make sure that billing is enabled for your project.](https://cloud.google.com/billing/docs/how-to/modify-project)\n",
        "3. [Enable the BigQuery Storage API](https://cloud.google.com/bigquery/docs/reference/storage/#enabling_the_api)\n",
        "4. Enter your project ID in the cell below. Then run the cell to make sure the\n",
        "Cloud SDK uses the right project for all the commands in this notebook.\n",
        "\n",
        "Note: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` into these commands."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "upgCc3gXybsA"
      },
      "source": [
        "Install the required packages, then restart the runtime."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QeBQuayhuvhg"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  # Use Colab's preinstalled TensorFlow 2.x\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tu01THzWcE-J"
      },
      "outputs": [],
      "source": [
        "!pip install fastavro\n",
        "!pip install tensorflow-io==0.9.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YUj0878jPyz7"
      },
      "outputs": [],
      "source": [
        "!pip install google-cloud-bigquery-storage"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yZmI7l_GykcW"
      },
      "source": [
        "Authenticate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tUo3GJNxxZbc"
      },
      "outputs": [],
      "source": [
        "from google.colab import auth\n",
        "auth.authenticate_user()\n",
        "print('Authenticated')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "czF1KlC6y8fB"
      },
      "source": [
        "Set your PROJECT ID"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fbQR-bT_xgba"
      },
      "outputs": [],
      "source": [
        "PROJECT_ID = \"<YOUR PROJECT>\" #@param {type:\"string\"}\n",
        "! gcloud config set project $PROJECT_ID\n",
        "%env GCLOUD_PROJECT=$PROJECT_ID"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hOqN91M4y_9Y"
      },
      "source": [
        "Import Python libraries, define constants"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G3p2ICG80Z1t"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "import os\n",
        "from six.moves import urllib\n",
        "import tempfile\n",
        "\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "\n",
        "from google.cloud import bigquery\n",
        "from google.api_core.exceptions import GoogleAPIError\n",
        "\n",
        "LOCATION = 'us'\n",
        "\n",
        "# Storage directory\n",
        "DATA_DIR = os.path.join(tempfile.gettempdir(), 'census_data')\n",
        "\n",
        "# Download options.\n",
        "DATA_URL = 'https://storage.googleapis.com/cloud-samples-data/ml-engine/census/data'\n",
        "TRAINING_FILE = 'adult.data.csv'\n",
        "EVAL_FILE = 'adult.test.csv'\n",
        "TRAINING_URL = '%s/%s' % (DATA_URL, TRAINING_FILE)\n",
        "EVAL_URL = '%s/%s' % (DATA_URL, EVAL_FILE)\n",
        "\n",
        "DATASET_ID = 'census_dataset'\n",
        "TRAINING_TABLE_ID = 'census_training_table'\n",
        "EVAL_TABLE_ID = 'census_eval_table'\n",
        "\n",
        "CSV_SCHEMA = [\n",
        "      bigquery.SchemaField(\"age\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"workclass\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"fnlwgt\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"education\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"education_num\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"marital_status\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"occupation\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"relationship\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"race\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"gender\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"capital_gain\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"capital_loss\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"hours_per_week\", \"FLOAT64\"),\n",
        "      bigquery.SchemaField(\"native_country\", \"STRING\"),\n",
        "      bigquery.SchemaField(\"income_bracket\", \"STRING\"),\n",
        "  ]\n",
        "\n",
        "UNUSED_COLUMNS = [\"fnlwgt\", \"education_num\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ooBfVnd-Xxd-"
      },
      "source": [
        "## Import census data into BigQuery"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0qt6wUD_2XFT"
      },
      "source": [
        "Define helper functions that load data into BigQuery"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7mMI7uW_2vP5"
      },
      "outputs": [],
      "source": [
        "def create_bigquery_dataset_if_necessary(dataset_id):\n",
        "  # Construct a full Dataset object to send to the API.\n",
        "  client = bigquery.Client(project=PROJECT_ID)\n",
        "  dataset = bigquery.Dataset(bigquery.dataset.DatasetReference(PROJECT_ID, dataset_id))\n",
        "  dataset.location = LOCATION\n",
        "\n",
        "  try:\n",
        "    dataset = client.create_dataset(dataset)  # API request\n",
        "    return True\n",
        "  except GoogleAPIError as err:\n",
        "    # HTTP 409 (Conflict) means the dataset already exists; re-raise anything else.\n",
        "    if getattr(err, 'code', None) != 409:\n",
        "      raise\n",
        "  return False\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y3Cr-DEfwNhK"
      },
      "outputs": [],
      "source": [
        "def load_data_into_bigquery(url, table_id):\n",
        "  create_bigquery_dataset_if_necessary(DATASET_ID)\n",
        "  client = bigquery.Client(project=PROJECT_ID)\n",
        "  dataset_ref = client.dataset(DATASET_ID)\n",
        "  table_ref = dataset_ref.table(table_id)\n",
        "  job_config = bigquery.LoadJobConfig()\n",
        "  job_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE\n",
        "  job_config.source_format = bigquery.SourceFormat.CSV\n",
        "  job_config.schema = CSV_SCHEMA\n",
        "\n",
        "  load_job = client.load_table_from_uri(\n",
        "      url, table_ref, job_config=job_config\n",
        "  )\n",
        "  print(\"Starting job {}\".format(load_job.job_id))\n",
        "\n",
        "  load_job.result()  # Waits for table load to complete.\n",
        "  print(\"Job finished.\")\n",
        "\n",
        "  destination_table = client.get_table(table_ref)\n",
        "  print(\"Loaded {} rows.\".format(destination_table.num_rows))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qSA0RIAZZEFZ"
      },
      "source": [
        "Load the census data into BigQuery."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wFZcK03-YDm4"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Starting job 2ceffef8-e6e4-44bb-9e86-3d97b0501187\n",
            "Job finished.\n",
            "Loaded 32561 rows.\n",
            "Starting job bf66f1b3-2506-408b-9009-c19f4ae9f58a\n",
            "Job finished.\n",
            "Loaded 16278 rows.\n"
          ]
        }
      ],
      "source": [
        "load_data_into_bigquery(TRAINING_URL, TRAINING_TABLE_ID)\n",
        "load_data_into_bigquery(EVAL_URL, EVAL_TABLE_ID)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KpUVI8IR2iXH"
      },
      "source": [
        "Confirm that the data was imported.\n",
        "\n",
        "TODO: In the query below, replace \\<YOUR PROJECT\\> with your PROJECT_ID.\n",
        "\n",
        "Note: `--use_bqstorage_api` fetches the data through the BigQuery Storage API, which also confirms that you are authorized to use it. Make sure the API is enabled for your project: https://cloud.google.com/bigquery/docs/reference/storage/#enabling_the_api\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CVy3UkDgx2zi"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>age</th>\n",
              "      <th>workclass</th>\n",
              "      <th>fnlwgt</th>\n",
              "      <th>education</th>\n",
              "      <th>education_num</th>\n",
              "      <th>marital_status</th>\n",
              "      <th>occupation</th>\n",
              "      <th>relationship</th>\n",
              "      <th>race</th>\n",
              "      <th>gender</th>\n",
              "      <th>capital_gain</th>\n",
              "      <th>capital_loss</th>\n",
              "      <th>hours_per_week</th>\n",
              "      <th>native_country</th>\n",
              "      <th>income_bracket</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>39.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>297847.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Other-service</td>\n",
              "      <td>Wife</td>\n",
              "      <td>Black</td>\n",
              "      <td>Female</td>\n",
              "      <td>3411.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>34.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&lt;=50K</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>72.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>74141.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Exec-managerial</td>\n",
              "      <td>Wife</td>\n",
              "      <td>Asian-Pac-Islander</td>\n",
              "      <td>Female</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>48.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&gt;50K</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>45.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>178215.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Machine-op-inspct</td>\n",
              "      <td>Wife</td>\n",
              "      <td>White</td>\n",
              "      <td>Female</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&gt;50K</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>31.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>86958.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Exec-managerial</td>\n",
              "      <td>Wife</td>\n",
              "      <td>White</td>\n",
              "      <td>Female</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>40.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&lt;=50K</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>55.0</td>\n",
              "      <td>Private</td>\n",
              "      <td>176012.0</td>\n",
              "      <td>9th</td>\n",
              "      <td>5.0</td>\n",
              "      <td>Married-civ-spouse</td>\n",
              "      <td>Tech-support</td>\n",
              "      <td>Wife</td>\n",
              "      <td>White</td>\n",
              "      <td>Female</td>\n",
              "      <td>0.0</td>\n",
              "      <td>0.0</td>\n",
              "      <td>23.0</td>\n",
              "      <td>United-States</td>\n",
              "      <td>&lt;=50K</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "    age workclass    fnlwgt  ... hours_per_week  native_country income_bracket\n",
              "0  39.0   Private  297847.0  ...           34.0   United-States          <=50K\n",
              "1  72.0   Private   74141.0  ...           48.0   United-States           >50K\n",
              "2  45.0   Private  178215.0  ...           40.0   United-States           >50K\n",
              "3  31.0   Private   86958.0  ...           40.0   United-States          <=50K\n",
              "4  55.0   Private  176012.0  ...           23.0   United-States          <=50K\n",
              "\n",
              "[5 rows x 15 columns]"
            ]
          },
          "execution_count": 10,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "%%bigquery --use_bqstorage_api\n",
        "SELECT * FROM `<YOUR PROJECT>.census_dataset.census_training_table` LIMIT 5"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tOu-pCksYTtE"
      },
      "source": [
        "## Load census data into a TensorFlow Dataset using the BigQuery reader"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6_Gm8Mzh62yF"
      },
      "source": [
        "Read the census data from BigQuery and transform it into a TensorFlow Dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NgPd9w5m06In"
      },
      "outputs": [],
      "source": [
        "from tensorflow_io.bigquery import BigQueryClient\n",
        "from tensorflow_io.bigquery import BigQueryReadSession\n",
        "\n",
        "def transform_row(row_dict):\n",
        "  # Trim surrounding whitespace from every string column\n",
        "  trimmed_dict = {column:\n",
        "                  (tf.strings.strip(tensor) if tensor.dtype == tf.string else tensor)\n",
        "                  for (column, tensor) in row_dict.items()\n",
        "                  }\n",
        "  # Remove the label column from the features\n",
        "  income_bracket = trimmed_dict.pop('income_bracket')\n",
        "  # Convert the label to 1.0 for '>50K' and 0.0 otherwise\n",
        "  income_bracket_float = tf.cast(tf.equal(income_bracket, '>50K'), tf.float32)\n",
        "  return (trimmed_dict, income_bracket_float)\n",
        "\n",
        "def read_bigquery(table_name):\n",
        "  tensorflow_io_bigquery_client = BigQueryClient()\n",
        "  read_session = tensorflow_io_bigquery_client.read_session(\n",
        "      \"projects/\" + PROJECT_ID,\n",
        "      PROJECT_ID, table_name, DATASET_ID,\n",
        "      # Columns to read: every schema field except the unused ones\n",
        "      list(field.name for field in CSV_SCHEMA\n",
        "           if field.name not in UNUSED_COLUMNS),\n",
        "      # FLOAT64 columns are read as tf.float64, all others as tf.string\n",
        "      list(tf.float64 if field.field_type == 'FLOAT64'\n",
        "           else tf.string for field in CSV_SCHEMA\n",
        "           if field.name not in UNUSED_COLUMNS),\n",
        "      requested_streams=2)\n",
        "\n",
        "  # Read the session's streams in parallel, then split features and label\n",
        "  dataset = read_session.parallel_read_rows()\n",
        "  transformed_ds = dataset.map(transform_row)\n",
        "  return transformed_ds\n"
      ]
    },
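    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The column selection in `read_bigquery` and the label conversion in the row transform defined above can be checked without a GCP project or TensorFlow. This is a minimal plain-Python sketch, not part of the pipeline: `Field` is a hypothetical stand-in for `bigquery.SchemaField`, only a subset of `CSV_SCHEMA` is used, type names are plain strings rather than TensorFlow dtypes, and `transform_row_py` is a hypothetical helper that mirrors the trimming and `>50K` labeling on an ordinary dict.\n",
        "\n",
        "```python\n",
        "from collections import namedtuple\n",
        "\n",
        "# Hypothetical stand-in for bigquery.SchemaField (name and field_type only).\n",
        "Field = namedtuple('Field', ['name', 'field_type'])\n",
        "\n",
        "# A subset of CSV_SCHEMA, plus the same unused columns.\n",
        "SCHEMA = [Field('age', 'FLOAT64'), Field('workclass', 'STRING'),\n",
        "          Field('fnlwgt', 'FLOAT64'), Field('education', 'STRING'),\n",
        "          Field('education_num', 'FLOAT64'), Field('income_bracket', 'STRING')]\n",
        "UNUSED = ['fnlwgt', 'education_num']\n",
        "\n",
        "# Same filtering as read_bigquery: skip unused columns, FLOAT64 -> float64, else string.\n",
        "selected = [f.name for f in SCHEMA if f.name not in UNUSED]\n",
        "types = ['float64' if f.field_type == 'FLOAT64' else 'string'\n",
        "         for f in SCHEMA if f.name not in UNUSED]\n",
        "\n",
        "def transform_row_py(row):\n",
        "  # Trim strings, pop the label, and map '>50K' to 1.0 and anything else to 0.0.\n",
        "  trimmed = {k: (v.strip() if isinstance(v, str) else v) for k, v in row.items()}\n",
        "  label = trimmed.pop('income_bracket')\n",
        "  return trimmed, (1.0 if label == '>50K' else 0.0)\n",
        "\n",
        "features, label = transform_row_py(\n",
        "    {'age': 72.0, 'workclass': ' Private', 'education': ' 9th', 'income_bracket': ' >50K'})\n",
        "print(selected)  # ['age', 'workclass', 'education', 'income_bracket']\n",
        "print(types)     # ['float64', 'string', 'string', 'string']\n",
        "print(features, label)  # {'age': 72.0, 'workclass': 'Private', 'education': '9th'} 1.0\n",
        "```"
      ]
    },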
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4_NlkxZt1rwR"
      },
      "outputs": [],
      "source": [
        "BATCH_SIZE = 32\n",
        "\n",
        "training_ds = read_bigquery(TRAINING_TABLE_ID).shuffle(10000).batch(BATCH_SIZE)\n",
        "eval_ds = read_bigquery(EVAL_TABLE_ID).batch(BATCH_SIZE)"
      ]
    },
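    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With `BATCH_SIZE = 32`, Keras runs one step per batch, and `batch()` keeps the final partial batch by default (`drop_remainder=False`). The arithmetic below uses the row counts reported by the load jobs earlier in this notebook to show where the `1018/1018` and `509/509` step counts in the training and evaluation logs come from.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "BATCH_SIZE = 32\n",
        "TRAINING_ROWS = 32561  # reported by the training-table load job\n",
        "EVAL_ROWS = 16278      # reported by the eval-table load job\n",
        "\n",
        "# A partial last batch still counts as a step, so round up.\n",
        "train_steps = math.ceil(TRAINING_ROWS / BATCH_SIZE)\n",
        "eval_steps = math.ceil(EVAL_ROWS / BATCH_SIZE)\n",
        "print(train_steps, eval_steps)  # 1018 509\n",
        "```"
      ]
    },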
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S_iXxNdSYsDO"
      },
      "source": [
        "## Define feature columns"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UClHDwcyhFky"
      },
      "outputs": [],
      "source": [
        "def get_categorical_feature_values(column):\n",
        "  # Query the distinct, whitespace-trimmed values of a column in the training table\n",
        "  query = 'SELECT DISTINCT TRIM({}) FROM `{}`.{}.{}'.format(column, PROJECT_ID, DATASET_ID, TRAINING_TABLE_ID)\n",
        "  client = bigquery.Client(project=PROJECT_ID)\n",
        "  query_job = client.query(query)\n",
        "  result = query_job.to_dataframe()\n",
        "  # Return the single result column as a NumPy array\n",
        "  return result.values[:, 0]"
      ]
    },
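    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the SQL that the helper above sends for one column, the query template can be formatted on its own. This sketch uses `my-project`, a hypothetical placeholder for your project ID, together with the dataset and table names defined earlier.\n",
        "\n",
        "```python\n",
        "project_id = 'my-project'  # hypothetical placeholder project ID\n",
        "dataset_id = 'census_dataset'\n",
        "table_id = 'census_training_table'\n",
        "\n",
        "# Same template as get_categorical_feature_values, filled in for 'workclass'.\n",
        "query = 'SELECT DISTINCT TRIM({}) FROM `{}`.{}.{}'.format(\n",
        "    'workclass', project_id, dataset_id, table_id)\n",
        "print(query)\n",
        "# SELECT DISTINCT TRIM(workclass) FROM `my-project`.census_dataset.census_training_table\n",
        "```\n",
        "\n",
        "Each distinct value the query returns becomes one entry in the vocabulary list of that column's categorical feature."
      ]
    },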
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h9aAs1ZAtQr-"
      },
      "outputs": [],
      "source": [
        "from tensorflow import feature_column\n",
        "\n",
        "feature_columns = []\n",
        "\n",
        "# numeric cols\n",
        "for header in ['capital_gain', 'capital_loss', 'hours_per_week']:\n",
        "  feature_columns.append(feature_column.numeric_column(header))\n",
        "\n",
        "# categorical cols\n",
        "for header in ['workclass', 'marital_status', 'occupation', 'relationship',\n",
        "               'race', 'native_country', 'education']:\n",
        "  categorical_feature = feature_column.categorical_column_with_vocabulary_list(\n",
        "        header, get_categorical_feature_values(header))\n",
        "  categorical_feature_one_hot = feature_column.indicator_column(categorical_feature)\n",
        "  feature_columns.append(categorical_feature_one_hot)\n",
        "\n",
        "# bucketized cols\n",
        "age = feature_column.numeric_column('age')\n",
        "age_buckets = feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])\n",
        "feature_columns.append(age_buckets)\n",
        "\n",
        "feature_layer = tf.keras.layers.DenseFeatures(feature_columns)"
      ]
    },
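    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Both `bucketized_column` and `indicator_column` turn one input value into a one-hot vector. The plain-Python sketch below approximates that for a single example, assuming that bucket `i` covers `[boundaries[i-1], boundaries[i])`, that ages below 18 fall into bucket 0 and ages of 65 or more into the last bucket, and that out-of-vocabulary values produce an all-zero vector (the default when no OOV buckets are configured). The `race_vocab` list is a hypothetical example, not the vocabulary queried from BigQuery.\n",
        "\n",
        "```python\n",
        "import bisect\n",
        "\n",
        "AGE_BOUNDARIES = [18, 25, 30, 35, 40, 45, 50, 55, 60, 65]\n",
        "\n",
        "def age_bucket_one_hot(age):\n",
        "  # 10 boundaries give 11 buckets; bisect_right finds the bucket that holds age.\n",
        "  vec = [0.0] * (len(AGE_BOUNDARIES) + 1)\n",
        "  vec[bisect.bisect_right(AGE_BOUNDARIES, age)] = 1.0\n",
        "  return vec\n",
        "\n",
        "def vocab_one_hot(value, vocabulary):\n",
        "  # One slot per vocabulary entry; unknown values leave every slot at 0.0.\n",
        "  vec = [0.0] * len(vocabulary)\n",
        "  if value in vocabulary:\n",
        "    vec[vocabulary.index(value)] = 1.0\n",
        "  return vec\n",
        "\n",
        "# Hypothetical vocabulary for the 'race' column.\n",
        "race_vocab = ['White', 'Black', 'Asian-Pac-Islander', 'Amer-Indian-Eskimo', 'Other']\n",
        "print(age_bucket_one_hot(39).index(1.0))  # 4, the [35, 40) bucket\n",
        "print(vocab_one_hot('Black', race_vocab))  # [0.0, 1.0, 0.0, 0.0, 0.0]\n",
        "```"
      ]
    },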
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HfO0bhXXY3GQ"
      },
      "source": [
        "## Build and train the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GnKZKOQX7Qwx"
      },
      "source": [
        "Build model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Sm-UsB5_zvt0"
      },
      "outputs": [],
      "source": [
        "Dense = tf.keras.layers.Dense\n",
        "model = tf.keras.Sequential(\n",
        "  [\n",
        "    feature_layer,\n",
        "    Dense(100, activation=tf.nn.relu, kernel_initializer='uniform'),\n",
        "    Dense(75, activation=tf.nn.relu),\n",
        "    Dense(50, activation=tf.nn.relu),\n",
        "    Dense(25, activation=tf.nn.relu),\n",
        "    # A single sigmoid unit outputs a score between 0 and 1 for the '>50K' class\n",
        "    Dense(1, activation=tf.nn.sigmoid)\n",
        "  ])\n",
        "\n",
        "# Compile Keras model\n",
        "model.compile(\n",
        "    optimizer='rmsprop',  # the Keras default, made explicit\n",
        "    loss='binary_crossentropy',\n",
        "    metrics=['accuracy'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f8bSDfQd7T1n"
      },
      "source": [
        "Train model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gPKrlFCN1y00"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "WARNING:tensorflow:Layer sequential is casting an input tensor from dtype float64 to the layer's dtype of float32, which is new behavior in TensorFlow 2.  The layer has dtype float32 because it's dtype defaults to floatx.\n",
            "\n",
            "If you intended to run this layer in float32, you can safely ignore this warning. If in doubt, this warning is likely only an issue if you are porting a TensorFlow 1.X model to TensorFlow 2.\n",
            "\n",
            "To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.\n",
            "\n",
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/feature_column/feature_column_v2.py:4276: IndicatorColumn._variable_shape (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
            "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/feature_column/feature_column_v2.py:4331: VocabularyListCategoricalColumn._num_buckets (from tensorflow.python.feature_column.feature_column_v2) is deprecated and will be removed in a future version.\n",
            "Instructions for updating:\n",
            "The old _FeatureColumn APIs are being deprecated. Please use the new FeatureColumn APIs instead.\n",
            "Epoch 1/5\n",
            "1018/1018 [==============================] - 17s 17ms/step - loss: 0.5985 - accuracy: 0.8105\n",
            "Epoch 2/5\n",
            "1018/1018 [==============================] - 10s 10ms/step - loss: 0.3670 - accuracy: 0.8324\n",
            "Epoch 3/5\n",
            "1018/1018 [==============================] - 11s 10ms/step - loss: 0.3487 - accuracy: 0.8393\n",
            "Epoch 4/5\n",
            "1018/1018 [==============================] - 11s 10ms/step - loss: 0.3398 - accuracy: 0.8435\n",
            "Epoch 5/5\n",
            "1018/1018 [==============================] - 11s 11ms/step - loss: 0.3377 - accuracy: 0.8455\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "<tensorflow.python.keras.callbacks.History at 0x7f978f5b91d0>"
            ]
          },
          "execution_count": 17,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "model.fit(training_ds, epochs=5)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UgKKliXRZG_G"
      },
      "source": [
        "## Evaluate the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SgNd5DdU7TEW"
      },
      "source": [
        "Evaluate model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8eGHVkmI5LBT"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "509/509 [==============================] - 8s 15ms/step - loss: 0.3338 - accuracy: 0.8398\n",
            "Accuracy 0.8398452\n"
          ]
        }
      ],
      "source": [
        "loss, accuracy = model.evaluate(eval_ds)\n",
        "print(\"Accuracy\", accuracy)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qBIWoGMP7bj1"
      },
      "source": [
        "Run predictions on a couple of sample inputs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aMou1t1xngXP"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([[0.5541261],\n",
              "       [0.6209938]], dtype=float32)"
            ]
          },
          "execution_count": 19,
          "metadata": {
            "tags": []
          },
          "output_type": "execute_result"
        }
      ],
      "source": [
        "sample_x = {\n",
        "    'age' : np.array([56, 36]), \n",
        "    'workclass': np.array(['Local-gov', 'Private']), \n",
        "    'education': np.array(['Bachelors', 'Bachelors']), \n",
        "    'marital_status': np.array(['Married-civ-spouse', 'Married-civ-spouse']), \n",
        "    'occupation': np.array(['Tech-support', 'Other-service']), \n",
        "    'relationship': np.array(['Husband', 'Husband']), \n",
        "    'race': np.array(['White', 'Black']), \n",
        "    'gender': np.array(['Male', 'Male']), \n",
        "    'capital_gain': np.array([0, 7298]), \n",
        "    'capital_loss': np.array([0, 0]), \n",
        "    'hours_per_week': np.array([40, 36]), \n",
        "    'native_country': np.array(['United-States', 'United-States'])\n",
        "  }\n",
        "\n",
        "model.predict(sample_x)"
      ]
    },
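    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Because the last layer uses a sigmoid activation, each value returned by `model.predict` can be read as the model's estimated probability that the person's income is above 50K. One common way to turn these scores into labels is a 0.5 cutoff; this sketch applies it to the two scores shown in the output above. The cutoff is an assumption, not something the tutorial specifies.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Scores copied from the model.predict output above, shape (2, 1).\n",
        "probs = np.array([[0.5541261], [0.6209938]], dtype=np.float32)\n",
        "\n",
        "# Assumed 0.5 cutoff: scores above it are labeled '>50K'.\n",
        "labels = np.where(probs[:, 0] > 0.5, '>50K', '<=50K')\n",
        "print(labels)  # ['>50K' '>50K']\n",
        "```"
      ]
    },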
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UhNtHfuxCGVy"
      },
      "source": [
        "## Resources\n",
        "\n",
        "* [Google Cloud BigQuery Overview](https://github.com/tensorflow/io/blob/master/tensorflow_io/bigquery/README.md)\n",
        "* [Training and prediction with Keras in AI Platform](https://colab.sandbox.google.com/github/GoogleCloudPlatform/cloudml-samples/blob/master/notebooks/tensorflow/getting-started-keras.ipynb)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "Tce3stUlHN0L"
      ],
      "name": "bigquery.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
